#!/usr/bin/env python3

import argparse
import atexit
from enum import IntEnum
import json
import os
import random
from subprocess import call, check_call, check_output, Popen, PIPE
import sys
from tempfile import NamedTemporaryFile
import threading
import time
from xml.dom import minidom
import xml.etree.ElementTree as etree


# Run-time configuration.  main() fills these in from the command line.
available_features = set()      # feature names passed via --features
device_name = None              # value of --device-name, if given
server_address = None           # value of --server-address, or the hostname
                                # discovered for the requested device
skip_unit_tests = False         # set by --skip-unit-tests
show_marionette_output = False  # set by --show-marionette-output

# Locations of the installed test files and test executables
tests_dir = '/Library/Application Support/MWorks/Developer/tests'
test_file_dir = os.path.join(tests_dir, 'XML')

mworks_core_test_runner = os.path.join(tests_dir, 'MWorksCoreTestRunner')
marionette_test = os.path.join(tests_dir,
                               'MarionetteTest.app',
                               'Contents',
                               'MacOS',
                               'MarionetteTest')

# A file is a marionette test if it has one of these extensions and its
# name does not end with the test data suffix
test_file_extensions = ('.xml', '.mwel', '.mwtest')
test_data_suffix = 'TESTDATA.xml'

# Names of the tests that finished with errors or with warnings
error_tests = []
warning_tests = []

# The Python interpreter will exit with status 1 if there's an exception and
# 2 if there's an argument-parsing error.  Use status 3 to indicate that
# some tests failed with errors and status 4 to denote warnings without
# errors.
class ExitStatus(IntEnum):
    """Exit statuses passed to sys.exit() by this script."""

    # Same value the interpreter uses for an uncaught exception
    ERROR = 1
    # Same value the interpreter uses for an argument-parsing error
    USAGE_ERROR = 2
    # Some tests failed with errors
    TEST_FAILURE = 3
    # Some tests produced warnings, but none failed with errors
    TEST_WARNING = 4

# Held while run_applescript() executes a script and while
# MWServerMemoryUsageMonitor sends a signal to the server process, so that
# the two never run concurrently
server_lock = threading.Lock()


# AppleScript passed to run_applescript() before each marionette test when
# testing against a local MWServer (no device name or server address).  It
# unmutes audio output, enforces a minimum output volume, and activates
# MWServer, waiting until the application is running.
marionette_prep_applescript = """

-- Ensure that audio isn't muted
if (output muted of (get volume settings)) then
    set volume output muted false
end if

-- Ensure minimum output volume
set minVolume to 10
if (output volume of (get volume settings)) < minVolume then
    set volume output volume minVolume
end if

-- Activate MWServer
activate application "/Applications/MWServer.app"
repeat until application "/Applications/MWServer.app" is running
    delay 1
end repeat

"""


# AppleScript that quits MWServer and waits until the application is no
# longer running.  Used both before the marionette tests start and at exit.
quit_server_applescript = """

quit application "/Applications/MWServer.app"
repeat while application "/Applications/MWServer.app" is running
    delay 1
end repeat

"""


def message(msg, *args):
    """Write msg, %-formatted with args, to standard error as one line."""
    formatted = msg % args
    sys.stderr.write('%s\n' % (formatted,))


def warning(msg, *args):
    """Emit msg (formatted with args) via message(), prefixed 'WARNING: '."""
    prefixed = 'WARNING: ' + msg
    message(prefixed, *args)


def error(status, msg, *args):
    """Emit msg (formatted with args) via message(), prefixed 'ERROR: ',
    then exit the interpreter with the given status.

    This function never returns normally: it always raises SystemExit.
    """
    prefixed = 'ERROR: ' + msg
    message(prefixed, *args)
    raise SystemExit(status)


def is_unit_test(name):
    """Return True if name is treated as a unit test name.

    The name is passed through /usr/bin/c++filt, and it counts as a unit
    test if the resulting output contains '::'.
    """
    filtered = check_output(['/usr/bin/c++filt', name])
    return b'::' in filtered


def open_devicectl(args):
    """Start 'xcrun devicectl --quiet' with the given extra arguments.

    Returns the Popen object for the started process without waiting for
    it to finish.
    """
    command = ['/usr/bin/xcrun', 'devicectl', '--quiet']
    command.extend(args)
    return Popen(command)


def run_devicectl(args):
    """Run a devicectl command to completion and return its JSON output.

    The command is given a 60-second timeout and told to write its JSON
    output to a temporary file, which is parsed and returned.  If devicectl
    exits with a non-zero status, the script exits via error().
    """
    with NamedTemporaryFile() as json_file:
        full_args = ['--timeout', '60', '--json-output', json_file.name]
        full_args.extend(args)

        exit_status = open_devicectl(full_args).wait()
        if exit_status != 0:
            error(ExitStatus.ERROR, 'devicectl command failed')

        return json.load(json_file)


def run_applescript(script):
    """Execute an AppleScript by piping it to /usr/bin/osascript.

    The script runs while server_lock is held.  If osascript exits with a
    non-zero status, the script exits via error().
    """
    script_bytes = script.encode('utf-8')
    with server_lock:
        osascript = Popen(['/usr/bin/osascript'], stdin=PIPE)
        osascript.communicate(script_bytes)
        if osascript.returncode != 0:
            error(ExitStatus.ERROR, 'osascript command failed')


def launch_mworks_app():
    """Launch the MWorks app on the device named by device_name.

    Starts a 'devicectl device process launch' command (terminating any
    existing instance of the app and attaching to its console) and returns
    the Popen object for that devicectl process.  The process is registered
    to be terminated when the interpreter exits.
    """
    launch_args = ['device', 'process', 'launch']
    launch_args += ['--device', device_name]
    launch_args += ['--terminate-existing', '--console']
    launch_args.append('org.mworks-project.MWServer-iOS')

    app_process = open_devicectl(launch_args)

    # Make sure the app doesn't outlive this script
    atexit.register(app_process.terminate)

    return app_process


class MWServerMemoryUsageMonitor(threading.Thread):
    """Thread that periodically checks MWServer's resident memory size and
    signals the process when the size exceeds a limit.

    The first time the limit is found to be exceeded, the process is sent
    TERM.  If the previous check had also found the limit exceeded, KILL is
    sent instead.
    """

    # Command used to list processes as "pid rss command-name" lines
    ps_args = ('/bin/ps', 'x', '-c', '-o', 'pid,rss,comm')
    # Time to sleep between checks
    check_interval_seconds = 10
    # Resident set size above which the server is signaled
    rss_limit_gib = 2.0

    @classmethod
    def get_info(cls):
        """Return {'pid': <str>, 'rss': <int>} for the first line of ps
        output that contains 'MWServer', or None if there is no such line.
        """
        ps_output = check_output(cls.ps_args, encoding='utf-8')
        for entry in ps_output.split('\n'):
            if 'MWServer' not in entry:
                continue
            fields = entry.split()
            return {'pid': fields[0], 'rss': int(fields[1])}
        return None

    def run(self):
        """Check memory usage forever, once per check interval."""
        previous_rss_gib = 0.0
        while True:
            time.sleep(self.check_interval_seconds)

            info = self.get_info()
            if not info:
                # No matching process right now; try again later
                continue

            # ps returns rss in units of 1024 bytes.  Convert this value to
            # GiB.
            rss_gib = info['rss'] / (1024 * 1024)

            if rss_gib > self.rss_limit_gib:
                # Escalate to KILL if the previous check was over the limit
                # as well
                if previous_rss_gib > self.rss_limit_gib:
                    signame = 'KILL'
                else:
                    signame = 'TERM'
                warning('MWServer real memory size (%g GiB) '
                        'exceeded limit (%g GiB); '
                        'killing process with signal %s',
                        rss_gib,
                        self.rss_limit_gib,
                        signame)
                with server_lock:
                    check_call(['/bin/kill', '-s', signame, info['pid']])

            previous_rss_gib = rss_gib


def run_unit_tests(test_names=None):
    """Run the MWorks core unit test runner and record failed tests.

    The runner is invoked with the path of a temporary results file,
    followed by test_names (if any).  An exit status other than 0 or 1 is
    treated as the runner having quit unexpectedly, and the script exits
    via error().  Otherwise, the text of every 'Name' element under the
    first 'FailedTests' element of the results file is appended to
    error_tests.
    """
    with NamedTemporaryFile() as results_file:
        command = [mworks_core_test_runner, results_file.name]
        if test_names:
            command.extend(test_names)

        exit_status = call(command)
        if exit_status not in (0, 1):
            runner_name = os.path.basename(mworks_core_test_runner)
            error(ExitStatus.ERROR, '%s quit unexpectedly', runner_name)

        # Parse while the temporary file still exists
        results = minidom.parse(results_file)

    failed_tests = results.getElementsByTagName('FailedTests')[0]
    failed_names = [name_node.firstChild.data for name_node in
                    failed_tests.getElementsByTagName('Name')]
    error_tests.extend(failed_names)


def iter_marionette_tests(top_dir=test_file_dir):
    """Yield the path of every marionette test file under top_dir.

    A file qualifies if its extension is one of test_file_extensions and
    its name does not end with test_data_suffix.
    """
    for dirpath, _dirnames, filenames in os.walk(top_dir):
        for filename in filenames:
            if filename.endswith(test_data_suffix):
                # Test data accompanies a test; it isn't a test itself
                continue
            extension = os.path.splitext(filename)[1]
            if extension in test_file_extensions:
                yield os.path.join(dirpath, filename)


def run_marionette_tests(test_names=None):
    """Run marionette tests, in random order, and record their outcomes.

    test_names is an optional sequence of test files and/or directories.  A
    name that does not exist as given is looked up relative to
    test_file_dir, and directories are expanded with
    iter_marionette_tests().  If test_names is empty or None, every test
    under test_file_dir is run.

    Where the tests run depends on the module-level settings:

    * If device_name is set, the MWorks app is launched on that device, and
      server_address is set to the device's hostname.
    * Otherwise, if server_address is set, the tests connect to the server
      that's already running at that address.
    * Otherwise, the local MWServer application is used.  It is quit before
      the first test and at exit, prepared before every test, and watched by
      a MWServerMemoryUsageMonitor thread.

    Each test is run with the MarionetteTest executable.  Exit status 0
    means the test passed, status 2 adds the test to warning_tests, and any
    other status adds it to error_tests.  Tests whose test data file
    requires features not in available_features are skipped.

    Exits the script via error() if the device's hostname cannot be
    determined or if a requested test file does not exist.
    """
    global server_address

    if device_name:
        # Launch the app.  We need to do this first, so that the connection is
        # already established when we try to get the hostname.
        mworks_app_process = launch_mworks_app()

        # Give devicectl a moment to establish the connection
        time.sleep(2)

        # Get the device's hostname
        for device in run_devicectl(['list', 'devices'])['result']['devices']:
            connection_props = device['connectionProperties']
            if (device['deviceProperties']['name'] == device_name and
                connection_props['tunnelState'] == 'connected'):
                server_address = connection_props['localHostnames'][0]
                break
        else:
            error(ExitStatus.ERROR,
                  'cannot determine hostname of requested device')

    if not test_names:
        test_file_iter = iter_marionette_tests
        if not server_address:
            # We're running all the tests, so this is probably an automated
            # run.  Ensure that the display is awake.
            check_call(['/usr/bin/caffeinate', '-u', '-t', '20'])
    else:
        def test_file_iter():
            for test in test_names:
                if not os.path.exists(test):
                    test = os.path.join(test_file_dir, test)
                if not os.path.isdir(test):
                    yield test
                else:
                    yield from iter_marionette_tests(test)

    # Randomize test order to increase our chances of catching bad
    # interactions between different core components
    all_test_names = list(test_file_iter())
    random.shuffle(all_test_names)

    if not server_address:
        # Ensure that server starts fresh
        run_applescript(quit_server_applescript)

        # Quit server on exit
        atexit.register(run_applescript, quit_server_applescript)

        # Install a monitor to prevent runaway memory usage in MWServer, which
        # started happening internal to NSTextView (used by the server console
        # window) in macOS 14.1.
        monitor = MWServerMemoryUsageMonitor(daemon=True)
        monitor.start()

    # The values below don't change from test to test, so compute them once

    # Include the trailing separator, so that a sibling directory whose name
    # merely starts with the test directory's name isn't mistaken for it
    test_dir_prefix = os.path.join(test_file_dir, '')

    saved_state_dir = os.path.expanduser(
        '~/Library/Saved Application State/'
        'org.mworks-project.MarionetteTest.savedState'
        )

    env = os.environ.copy()
    if server_address:
        env['MARIONETTE_SERVER_ADDRESS'] = server_address

    num_tests = len(all_test_names)

    for test_num, test_file in enumerate(all_test_names, start=1):
        if not os.path.isfile(test_file):
            error(ExitStatus.ERROR, 'no such file: "%s"', test_file)

        args = [marionette_test, test_file]

        # Test data is looked for first in a file specific to this test
        # (<test name>.TESTDATA.xml), then in a file shared by the test's
        # directory (TESTDATA.xml).  Only the first one found is used.
        test_data_candidates = (
            os.path.splitext(test_file)[0] + '.' + test_data_suffix,
            os.path.join(os.path.dirname(test_file), test_data_suffix),
            )
        missing_features = None
        for test_data in test_data_candidates:
            if os.path.isfile(test_data):
                args.append(test_data)

                # Determine feature requirements
                test_data_tree = etree.parse(test_data)
                required_features = \
                    set(e.attrib['name'] for e in
                        test_data_tree.iterfind('.//requirements/feature'))
                missing_features = required_features - available_features

                break

        if test_file.startswith(test_dir_prefix):
            test_name = test_file[len(test_dir_prefix):]
        else:
            test_name = test_file

        print('(%d/%d) %s:' % (test_num, num_tests, test_name),
              end=('\n' if show_marionette_output else ' '),
              flush=True)

        if missing_features:
            # Sort the names, so that the output is the same on every run
            print('skipped (requires %s)' %
                  ', '.join(sorted(missing_features)))
            continue

        # For OS X 10.7: Remove any saved state, which can lead to a
        # "Restore Windows" dialog that hangs MarionetteTest until
        # it's dismissed
        if os.path.isdir(saved_state_dir):
            check_call(['/bin/rm', '-Rf', saved_state_dir])

        if device_name:
            if mworks_app_process.poll() is not None:
                warning('MWorks app terminated with status %d; re-launching',
                        mworks_app_process.returncode)
                mworks_app_process = launch_mworks_app()
        elif not server_address:
            run_applescript(marionette_prep_applescript)

        # When the test's output is being shown, stderr is inherited, and
        # communicate() returns None for it
        cmd = Popen(args,
                    env=env,
                    stderr=(None if show_marionette_output else PIPE))
        output = cmd.communicate()[1]

        if cmd.returncode == 0:
            print('OK')
            continue

        if cmd.returncode == 2:
            print('warning')
            warning_tests.append(test_name)
        else:
            print('error')
            error_tests.append(test_name)

        # Show any captured output.  If there is none, print nothing, rather
        # than a single empty quoted line.
        captured = (output or b'').strip()
        if captured:
            for line in captured.split(b'\n'):
                # Replace undecodable bytes, so that a test that writes
                # non-UTF-8 data can't abort the whole run
                print('    >', line.decode('utf-8', errors='replace'))


def run_tests(unit_test_names, marionette_test_names):
    """Run the requested tests and print a summary of the results.

    If names of only one kind of test are given, only tests of that kind
    are run.  If no names are given at all, all marionette tests are run,
    preceded by all unit tests unless skip_unit_tests is set.

    The module-level error_tests and warning_tests lists are reset before
    the run and hold the sorted names of the problem tests afterward.
    """
    global error_tests, warning_tests

    error_tests, warning_tests = [], []

    if unit_test_names:
        run_unit_tests(unit_test_names)
    elif not marionette_test_names and not skip_unit_tests:
        run_unit_tests()

    if marionette_test_names:
        run_marionette_tests(marionette_test_names)
    elif not unit_test_names:
        run_marionette_tests()

    # Ensure consistent output order for test names, with unit tests listed
    # before marionette tests
    def sort_key(name):
        return (0 if is_unit_test(name) else 1, name)

    for problem_tests in (error_tests, warning_tests):
        problem_tests.sort(key=sort_key)

    if error_tests or warning_tests:
        # Failures go last, so they're easy to find at the end of the output
        sections = (('\nWARNINGS:\n', warning_tests),
                    ('\nFAILURES:\n', error_tests))
        for heading, names in sections:
            if not names:
                continue
            print(heading)
            for test_name in names:
                print('    %s' % test_name)
    else:
        print('\nALL TESTS PASS')
    print()


def main():
    """Parse the command line, run the requested tests, and exit with a
    status that reflects the results.

    The parsed options are stored in the module-level settings.  Requested
    test names are split into unit tests and marionette tests with
    is_unit_test().  If an app bundle path is given, the bundle is installed
    on the requested device before any tests run.

    Exits with ExitStatus.TEST_FAILURE if any test failed with an error, or
    with ExitStatus.TEST_WARNING if there were only warnings.
    """
    global device_name, server_address, skip_unit_tests, show_marionette_output

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--features',
        default='',
        help='available features (as a comma-separated list)',
        )
    target_group = parser.add_mutually_exclusive_group()
    target_group.add_argument(
        '--device-name',
        help=('name of connected device on which to launch '
              'and test MWorks app'),
        )
    target_group.add_argument(
        '--server-address',
        help=('address of running MWServer or MWorks app '
              'instance'),
        )
    parser.add_argument(
        '--app-bundle-path',
        help=('path to MWorks app bundle to install and test '
              '(requires --device-name)'),
        )
    parser.add_argument(
        '--skip-unit-tests',
        action='store_true',
        help='do not run unit tests',
        )
    parser.add_argument(
        '--show-marionette-output',
        action='store_true',
        help='display output from all marionette tests',
        )
    parser.add_argument(
        'test_names',
        metavar='test_name',
        nargs='*',
        help='run only the requested test(s)',
        )

    options = parser.parse_args()

    # Ignore empty entries in the feature list
    for feature in options.features.split(','):
        feature = feature.strip()
        if feature:
            available_features.add(feature)

    device_name = options.device_name
    server_address = options.server_address
    skip_unit_tests = options.skip_unit_tests
    show_marionette_output = options.show_marionette_output

    unit_test_names = []
    marionette_test_names = []
    for name in options.test_names:
        if is_unit_test(name):
            unit_test_names.append(name)
        else:
            marionette_test_names.append(name)

    bundle_path = options.app_bundle_path
    if bundle_path:
        if device_name:
            install_args = ['device', 'install', 'app']
            install_args += ['--device', device_name]
            install_args.append(bundle_path)
            run_devicectl(install_args)
        else:
            error(ExitStatus.USAGE_ERROR,
                  'cannot install app bundle: no device specified')

    run_tests(unit_test_names, marionette_test_names)

    # Errors take precedence over warnings
    if error_tests:
        sys.exit(ExitStatus.TEST_FAILURE)
    elif warning_tests:
        sys.exit(ExitStatus.TEST_WARNING)


# Run the tests only when this file is executed as a script
if __name__ == '__main__':
    main()
